{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vD3L4qeREvg"
      },
      "source": [
        "##### Copyright 2024 The AI Edge Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qLCxmWRyRMZE"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8LYgHRFPRpS1"
      },
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4k5PoHrgJQOU"
      },
      "source": [
        "# JAX Model Conversion For TFLite\n",
        "## Overview\n",
        "Note: This API is new and experimental, and it is subject to change. We recommend using it with the `tf-nightly` pip package.\n",
        "\n",
        "This codelab demonstrates how to build a model for MNIST recognition using JAX and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the JAX-converted TFLite model with post-training quantization."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i8cfOBcjSByO"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/jax_conversion/jax_to_tflite\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lq-T8XZMJ-zv"
      },
      "source": [
        "## Prerequisites\n",
        "It's recommended to try this feature with the newest TensorFlow nightly pip build."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EV04hKdrnE4f"
      },
      "outputs": [],
      "source": [
        "!pip install tf-nightly --upgrade\n",
        "!pip install jax --upgrade"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vsilblGuGQa2"
      },
      "outputs": [],
      "source": [
        "# Make sure your JAX version is 0.4.20 or later.\n",
        "import jax\n",
        "jax.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PJeQhMUwH0oX"
      },
      "outputs": [],
      "source": [
        "!pip install orbax-export --upgrade"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j9_CVA0THQNc"
      },
      "outputs": [],
      "source": [
        "from orbax.export import ExportManager\n",
        "from orbax.export import JaxModule\n",
        "from orbax.export import ServingConfig"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAeY43k9KM55"
      },
      "source": [
        "## Data Preparation\n",
        "Download the MNIST data with the Keras datasets API and preprocess it: scale pixel values to [0, 1] and one-hot encode the labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qSOPSZJn1_Tj"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import functools\n",
        "\n",
        "import time\n",
        "import itertools\n",
        "\n",
        "import numpy.random as npr\n",
        "\n",
        "import jax.numpy as jnp\n",
        "from jax import jit, grad, random\n",
        "from jax.example_libraries import optimizers\n",
        "from jax.example_libraries import stax\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hdJIt3Da2Qn1"
      },
      "outputs": [],
      "source": [
        "def _one_hot(x, k, dtype=np.float32):\n",
        "  \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
        "  return np.array(x[:, None] == np.arange(k), dtype)\n",
        "\n",
        "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
        "train_images, test_images = train_images / 255.0, test_images / 255.0\n",
        "train_images = train_images.astype(np.float32)\n",
        "test_images = test_images.astype(np.float32)\n",
        "\n",
        "train_labels = _one_hot(train_labels, 10)\n",
        "test_labels = _one_hot(test_labels, 10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0eFhx85YKlEY"
      },
      "source": [
        "## Build the MNIST model with JAX"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mi3TKB9nnQdK"
      },
      "outputs": [],
      "source": [
        "def loss(params, batch):\n",
        "  inputs, targets = batch\n",
        "  preds = predict(params, inputs)\n",
        "  return -jnp.mean(jnp.sum(preds * targets, axis=1))\n",
        "\n",
        "def accuracy(params, batch):\n",
        "  inputs, targets = batch\n",
        "  target_class = jnp.argmax(targets, axis=1)\n",
        "  predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n",
        "  return jnp.mean(predicted_class == target_class)\n",
        "\n",
        "init_random_params, predict = stax.serial(\n",
        "    stax.Flatten,\n",
        "    stax.Dense(1024), stax.Relu,\n",
        "    stax.Dense(1024), stax.Relu,\n",
        "    stax.Dense(10), stax.LogSoftmax)\n",
        "\n",
        "rng = random.PRNGKey(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bRtnOBdJLd63"
      },
      "source": [
        "## Train \u0026 Evaluate the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SWbYRyj7LYZt"
      },
      "outputs": [],
      "source": [
        "step_size = 0.001\n",
        "num_epochs = 10\n",
        "batch_size = 128\n",
        "momentum_mass = 0.9\n",
        "\n",
        "\n",
        "num_train = train_images.shape[0]\n",
        "num_complete_batches, leftover = divmod(num_train, batch_size)\n",
        "num_batches = num_complete_batches + bool(leftover)\n",
        "\n",
        "def data_stream():\n",
        "  rng = npr.RandomState(0)\n",
        "  while True:\n",
        "    perm = rng.permutation(num_train)\n",
        "    for i in range(num_batches):\n",
        "      batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n",
        "      yield train_images[batch_idx], train_labels[batch_idx]\n",
        "batches = data_stream()\n",
        "\n",
        "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n",
        "\n",
        "@jit\n",
        "def update(i, opt_state, batch):\n",
        "  params = get_params(opt_state)\n",
        "  return opt_update(i, grad(loss)(params, batch), opt_state)\n",
        "\n",
        "_, init_params = init_random_params(rng, (-1, 28 * 28))\n",
        "opt_state = opt_init(init_params)\n",
        "itercount = itertools.count()\n",
        "\n",
        "print(\"\\nStarting training...\")\n",
        "for epoch in range(num_epochs):\n",
        "  start_time = time.time()\n",
        "  for _ in range(num_batches):\n",
        "    opt_state = update(next(itercount), opt_state, next(batches))\n",
        "  epoch_time = time.time() - start_time\n",
        "\n",
        "  params = get_params(opt_state)\n",
        "  train_acc = accuracy(params, (train_images, train_labels))\n",
        "  test_acc = accuracy(params, (test_images, test_labels))\n",
        "  print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n",
        "  print(\"Training set accuracy {}\".format(train_acc))\n",
        "  print(\"Test set accuracy {}\".format(test_acc))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Y1OZBhfQhOj"
      },
      "source": [
        "## Convert to TFLite model\n",
        "The conversion takes two steps:\n",
        "1. Wrap the `JAX` model in an `orbax` `JaxModule`, which exposes it as a TensorFlow function.\n",
        "2. Call the TFLite converter API to convert that function to a `.tflite` model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6pcqKZqdNTmn"
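      },
      "outputs": [],
      "source": [
        "# Wrap the trained parameters and predict function so TensorFlow can call them.\n",
        "# The polymorphic shape 'b, ...' leaves the leading batch dimension symbolic.\n",
        "jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')\n",
        "# Trace the default method with a fixed (1, 28, 28) float32 input and convert it.\n",
        "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n",
        "    [\n",
        "        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n",
        "            tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name=\"input\")\n",
        "        )\n",
        "    ]\n",
        ")\n",
        "\n",
        "tflite_model = converter.convert()\n",
        "with open('jax_mnist.tflite', 'wb') as f:\n",
        "  f.write(tflite_model)"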
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sqEhzaJPSPS1"
      },
      "source": [
        "## Check the Converted TFLite Model\n",
        "Compare the converted model's results with the JAX model's."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "acj2AYzjSlaY"
      },
      "outputs": [],
      "source": [
        "serving_func = functools.partial(predict, params)\n",
        "expected = serving_func(train_images[0:1])\n",
        "\n",
        "# Run the model with TensorFlow Lite\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n",
        "interpreter.invoke()\n",
        "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
        "\n",
        "# Assert that the TFLite model's result is consistent with the JAX model's.\n",
        "# assert_allclose takes explicit tolerances (assert_almost_equal expects a\n",
        "# count of decimal places instead).\n",
        "np.testing.assert_allclose(expected, result, rtol=1e-5, atol=1e-5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qy9Gp4H2SjBL"
      },
      "source": [
        "## Optimize the Model\n",
        "We will provide a `representative_dataset` to perform post-training quantization and optimize the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KI0rLV-Meg-2"
      },
      "outputs": [],
      "source": [
        "def representative_dataset():\n",
        "  \"\"\"Yields 1000 training images, one per step, to calibrate quantization ranges.\"\"\"\n",
        "  for i in range(1000):\n",
        "    x = train_images[i:i+1]\n",
        "    yield [x]\n",
        "\n",
        "x_input = jnp.zeros((1, 28, 28))\n",
        "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n",
        "    [serving_func], [[('x', x_input)]])\n",
        "# Configure quantization before calling convert().\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "converter.representative_dataset = representative_dataset\n",
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
        "tflite_quant_model = converter.convert()\n",
        "with open('jax_mnist_quant.tflite', 'wb') as f:\n",
        "  f.write(tflite_quant_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "15xQR3JZS8TV"
      },
      "source": [
        "## Evaluate the Optimized Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X3oOm0OaevD6"
      },
      "outputs": [],
      "source": [
        "expected = serving_func(train_images[0:1])\n",
        "\n",
        "# Run the model with TensorFlow Lite\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n",
        "interpreter.invoke()\n",
        "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
        "\n",
        "# Check that the quantized model's result roughly agrees with the JAX model's.\n",
        "# Quantization adds rounding error, so the check is loose: decimal=0 accepts\n",
        "# an absolute difference below 1.5 per element.\n",
        "np.testing.assert_almost_equal(expected, result, decimal=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QqHXCNa3myor"
      },
      "source": [
        "## Compare the Quantized Model Size\n",
        "The quantized model should be roughly a quarter of the size of the original model, because its weights are stored as 8-bit integers instead of 32-bit floats."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "imFPw007juVG"
      },
      "outputs": [],
      "source": [
        "!du -h jax_mnist.tflite\n",
        "!du -h jax_mnist_quant.tflite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N5WdO18wNfyn"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1UlzbQspn2an2kzlLWZBWhP_JEqvShxmi",
          "timestamp": 1705015454450
        },
        {
          "file_id": "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/overview.ipynb",
          "timestamp": 1698963811786
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
